from Utils import *

'''
DATA
'''

path = '../Social Bias Probing/Results/'

df = pd.read_csv(path+'SBIC-Pro-w-PPLs-PostProcess.csv')
# identities[category][identity] -> row (column -> value) from
# '<category>-identities-w-PPLs.csv'; used below as the per-model normaliser.
identities = {}
stereotypes = {}
unique_categories = df['category'].unique()
num_categories = len(unique_categories)
for category in unique_categories:
    temp = pd.read_csv(path+category+'-identities-w-PPLs.csv')
    temp = temp.drop_duplicates(subset='identity')
    print(category,len(temp))
    identities[category] = temp.set_index('identity').to_dict(orient='index')
    stereotypes[category] = df[df['category'] == category]['stereotype'].unique()

# Divide every probe score by the score of its identity (same category, same
# model). The denominators are gathered in one pass and divided column-wise,
# replacing per-cell df.loc writes inside iterrows. The former sort/unsort
# round-trip is dropped: it ended with sort_index, so row order was unchanged.
# A KeyError is still raised if an identity is missing from its identities file.
denominators = pd.DataFrame(
    [
        [identities[row_category][row_identity][LM] for LM in LMs_columns]
        for row_category, row_identity in zip(df['category'], df['identity'])
    ],
    index=df.index,
    columns=LMs_columns,
)
df[LMs_columns] = df[LMs_columns] / denominators
df.to_csv(path+'SBIC-Pro-w-Normalized-PPLs.csv', index=False)
# Vectorised log10 (DataFrame.applymap is deprecated since pandas 2.1).
df[LMs_columns] = np.log10(df[LMs_columns])
df.to_csv(path+'SBIC-Pro-w-Normalized-Log10-PPLs.csv', index=False)

'''
FUNCTIONS
'''

def rank_variance(df, aggregated=False, variances=None, lm_columns=None):
    """Rank language models by (mean) score variance, highest first.

    Non-aggregated mode: for every probe (rows sharing the same ``id``) take
    the sample variance (ddof=1) of each model column, then average those
    variances over all probes.

    Aggregated mode: ``df`` is not used; each model's value is the mean of its
    entries across the per-category dicts in ``variances``.

    Args:
        df: DataFrame with an ``id`` column and one column per model.
        aggregated: if True, average the precomputed ``variances`` instead.
        variances: mapping category -> {model: score}; required (non-empty)
            when ``aggregated`` is True.
        lm_columns: model columns to rank; defaults to the module-level
            ``LMs_columns``.

    Returns:
        dict model -> score rounded to 3 decimals, ordered by descending score.
        The dict is also printed.

    Raises:
        ValueError: if ``aggregated`` is True and ``variances`` is empty/None.
    """
    if lm_columns is None:
        lm_columns = LMs_columns
    res = {}
    if aggregated:
        if not variances:
            raise ValueError('aggregated=True requires a non-empty `variances` mapping')
        for LM in lm_columns:
            # Mean over the categories actually supplied, not a global count.
            res[LM] = sum(scores[LM] for scores in variances.values()) / len(variances)
    else:
        # One groupby gives the per-probe variance of every model column, with
        # probes in order of first appearance (sort=False).
        per_probe = df.groupby('id', sort=False)[list(lm_columns)].var()
        for LM in lm_columns:
            probe_variances = per_probe[LM].tolist()
            # Average over *all* probes. Previously the accumulator was reset
            # inside the probe loop, so only the last probe was counted.
            if probe_variances:
                res[LM] = sum(probe_variances) / len(probe_variances)
            else:
                res[LM] = float('nan')
    res = {key: round(value, 3) for key, value in res.items()}
    res = dict(sorted(res.items(), key=lambda item: item[1], reverse=True))
    print(res)
    return res
            
def top(df, category, col, k=100, lowest=True):
    """Report the most extreme identities or probes per model for one category.

    The mode is chosen by ``col``:

    * ``col == 'identity'``: the category is split into probes (rows sharing
      an ``id``). For every probe and model, the identity with the lowest (or
      highest) score "wins". Each identity is reported with the percentage of
      probes it won. Identities that never win are dropped, and at most ``k``
      entries are kept.
    * any other column (e.g. ``'stereotype'``): for every probe, the per-model
      spread ``max - min`` of its scores is computed. The ``k`` probes with the
      smallest (or largest) spread are reported, keyed by their ``col`` value
      with the spread rounded to 3 decimals. When several probes share a key,
      the first (best-ranked) score is kept.

    Relies on the module-level ``LMs_columns``, ``identities`` and
    ``similar_LMs``.

    Args:
        df: DataFrame with ``id``, ``category``, ``identity``, ``stereotype``
            columns and one column per model.
        category: value of ``df['category']`` to restrict to.
        col: ``'identity'`` or the column used as result key.
        k: maximum number of entries reported per model.
        lowest: rank ascending when True, descending otherwise.

    Returns:
        ``(res, most_similar_models)``. ``res`` maps model -> {key: score} in
        rank order, and ``most_similar_models`` is ``similar_LMs(res)``.
    """
    print('\n <<< LOWEST >>>' if lowest else '\n <<< HIGHEST >>>')
    res = {}
    df_category = df[df['category'] == category]
    unique_ids = df_category['id'].unique()
    n_unique_ids = len(unique_ids)
    # Split the category into probes once, instead of re-filtering per model.
    probes = [df_category[df_category['id'] == i] for i in unique_ids]
    print('\n - PER MODEL -')
    if col == 'identity':
        for LM in LMs_columns:
            # Every known identity starts at 0, so ties keep the identities-file order.
            counts = dict.fromkeys(identities[category], 0)
            for df_probe in probes:
                # Row 0 after sorting (ascending when `lowest`) is the winner.
                winner = df_probe.sort_values(by=[LM], ascending=lowest).iloc[0][col]
                counts[winner] += 1
            # Percentage of probes won; never-winning identities are dropped.
            # Filtering on the raw count also avoids any division when there are no probes.
            shares = {
                key: round((value / n_unique_ids) * 100, 3)
                for key, value in counts.items()
                if value != 0
            }
            ranked = sorted(shares.items(), key=lambda item: item[1], reverse=True)
            res[LM] = dict(ranked[:k])
            print(LM, res[LM])
    else:
        meta_columns = ['id', 'category', 'identity', 'stereotype']
        rows = []
        for df_probe in probes:
            meta = [df_probe[c].iloc[0] for c in meta_columns]
            # Delta between max and min score of this probe, per model.
            deltas = [df_probe[LM].max() - df_probe[LM].min() for LM in LMs_columns]
            rows.append(meta + deltas)
        # Build the frame in one go rather than growing an empty (object-dtype)
        # frame row by row. This keeps the model columns numeric, so they sort
        # and round correctly, and avoids quadratic appends.
        agg_df = pd.DataFrame(rows, columns=meta_columns + list(LMs_columns), index=unique_ids)
        for LM in LMs_columns:
            top_k = agg_df.sort_values(by=[LM], ascending=lowest).head(k)
            scores = {}
            for key, value in zip(top_k[col], top_k[LM].round(3)):
                # Several probes may share a `col` value; keep the first
                # (best-ranked) score instead of letting later rows overwrite it.
                scores.setdefault(key, value)
            res[LM] = scores
            print(LM, res[LM])
    print('\n - OVERLAP -')
    most_similar_models = similar_LMs(res)
    return res, most_similar_models

def get_nonzero_keys(dictionary):
    """Collect the keys of ``dictionary`` whose value is strictly greater than 0.

    Despite the name, negative values are excluded too; only positive entries
    are kept.
    """
    return set(filter(lambda key: dictionary[key] > 0, dictionary))

def similar_LMs(lm_dictionaries):
    """Pair every model with the other model whose positive keys overlap most.

    Args:
        lm_dictionaries: mapping model name -> {key: score}. Only keys with a
            score > 0 (via ``get_nonzero_keys``) take part in the comparison.

    Returns:
        dict model -> (most similar other model, number of shared keys).
        On ties the first model in iteration order wins. Models whose best
        overlap is 0 are omitted. Each pairing is also printed.
    """
    # Compute each model's positive-key set once, up front.
    key_sets = {name: get_nonzero_keys(scores) for name, scores in lm_dictionaries.items()}

    most_similar_models = {}
    for name, own_keys in key_sets.items():
        best_name, best_overlap = None, -1
        for other_name, other_keys in key_sets.items():
            if other_name == name:
                continue
            shared = len(own_keys & other_keys)
            if shared > best_overlap:
                best_name, best_overlap = other_name, shared
        if best_overlap > 0:
            most_similar_models[name] = (best_name, best_overlap)

    for lm_name, (similar_model, overlap) in most_similar_models.items():
        print(f"The most similar model to '{lm_name}' is '{similar_model}' with an overlap of {overlap} keys.")
    return most_similar_models

'''
CALLS
'''

print('\n\n\n\n ---- RANK W.R.T. VARIANCE ----')
variances = {}
print('\n - PER CATEGORY -')
for category in unique_categories:
    print('\n' + category)
    df_category = df[df['category'] == category]
    variances[category] = rank_variance(df_category)
print('\n - AGGREGATED -')
rank_variance(df, True, variances)

# One row per model: [model, variance per category in unique_categories order].
# NOTE(review): the header assumes the categories appear in the CSV in exactly
# the order Culture, Gender, Disabled, Race; otherwise the columns are
# mislabelled. Confirm against the data.
data = [[LM] + [variance[LM] for variance in variances.values()] for LM in LMs_columns]
table2 = pd.DataFrame(data, columns=['Model', 'Culture', 'Gender', 'Disabled', 'Race'])
print(table2)
table2.to_csv(path+'Table2.csv', index=False)

print('\n\n\n\n ---- TOP ----')
res_top_low_1_identities = {}
res_top_low_1_stereotypes = {}
res_top_high_1_identities = {}
res_top_high_1_stereotypes = {}
# category -> {model: (most similar model, overlap)}. These used to be
# overwritten on every iteration, so only the last category reached the JSON.
most_similar_models_ids = {}
most_similar_models_stereo = {}
print('\n - PER CATEGORY -')
for category in unique_categories:
    print('\n' + category)
    res_top_low_1_identities[category], most_similar_models_ids[category] = top(df, category, 'identity')
    res_top_low_1_stereotypes[category], most_similar_models_stereo[category] = top(df, category, 'stereotype')
    # Store only the per-model results, consistent with the "low" dictionaries.
    res_top_high_1_identities[category], _ = top(df, category, 'identity', k=100, lowest=False)
    res_top_high_1_stereotypes[category], _ = top(df, category, 'stereotype', k=100, lowest=False)
with open(path+'most_similar_models_ids.json', 'w') as fp:
    json.dump(most_similar_models_ids, fp)
with open(path+'most_similar_models_stereo.json', 'w') as fp:
    json.dump(most_similar_models_stereo, fp)

# Top-3 lowest identities per (category, model).
data = [
    [category, model, key, value]
    for category, models in res_top_low_1_identities.items()
    for model, id_scores in models.items()
    for key, value in list(id_scores.items())[:3]
]
table3id = pd.DataFrame(data, columns=['Category', 'Model', 'Identity', 'Identity Score'])
print(table3id)
table3id.to_csv(path+'Table3Id.csv', index=False)

# Top-3 lowest stereotypes per (category, model). The loop variable no longer
# shadows the module-level `stereotypes` dict.
data = [
    [category, model, key, value]
    for category, models in res_top_low_1_stereotypes.items()
    for model, stereo_scores in models.items()
    for key, value in list(stereo_scores.items())[:3]
]
table3stereo = pd.DataFrame(data, columns=['Category', 'Model', 'Stereotype', 'Stereotype Score'])
print(table3stereo)
table3stereo.to_csv(path+'Table3Stereo.csv', index=False)